Skip to content

Add support for conv1d, Phase 3#2941

Open
mdgrs wants to merge 2 commits into
google:mainfrom
belfortlabs:mdgrs/conv1dPhase3
Open

Add support for conv1d, Phase 3#2941
mdgrs wants to merge 2 commits into
google:mainfrom
belfortlabs:mdgrs/conv1dPhase3

Conversation

@mdgrs
Copy link
Copy Markdown
Collaborator

@mdgrs mdgrs commented May 12, 2026

This builds on the previous work #2919 (and that previous work is the first commit)

It uses the work on Conv2d_nchw_fchw as a blueprint to add support for Con1d_ncw_fcw. In my tests, this is the linalg operation that my pytorch conv1d layer gets lowered to.

It makes conv1d turn green for #2923

@mdgrs mdgrs requested review from asraa and j2kun May 12, 2026 08:39
// h' = hi * g + (ci % g**2) // g
// w' = wi * g + (ci % g)
// 3. Flatten (gW, gH, C) into idx_out = (c * g * h) * w' + (c) * h' + c'
// FIXME why are these interchanged???
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like h and w were exchanged in the definitions of hOut and wOut

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@asraa to comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh no! @mdgrs could you link this into an issue? i don't want to block this but I'm also noticing an error in flattening the final output in the ISL string.

Copy link
Copy Markdown
Collaborator

@asraa asraa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! I'll take a look - could you try rebasing? Something strange happened in the diff where it appears it deletes some recent changes

@mdgrs mdgrs force-pushed the mdgrs/conv1dPhase3 branch from b9b4a7b to 6fbdb50 Compare May 12, 2026 16:29
Copy link
Copy Markdown
Collaborator

@j2kun j2kun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally LGTM.

Comment thread lib/Kernel/Kernel.cpp Outdated
// h' = hi * g + (ci % g**2) // g
// w' = wi * g + (ci % g)
// 3. Flatten (gW, gH, C) into idx_out = (c * g * h) * w' + (c) * h' + c'
// FIXME why are these interchanged???
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@asraa to comment

int64_t filterSize = filterType.getDimSize(2);
int64_t outputW = (dataSize + 2 * padding - filterSize) / stride + 1;

auto rowInterchangeRelation = get1dConvRowInterchangeRelation(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sanity check: you want this because the cross-channel filters are stacked a certain way that, without the row interchange, requires more diagonals.

I admit it took us a while to derive the conv2d version, and I am struggling to remember how this interacted with the linalg canonicalization lowering it to a loop (which is also present in the PR for conv1d). @asraa didn't we end up having a problem with converting it to a loop, and didn't we ultimately resort to a dedicated lowering for the Conv2DNchwFchwOp as a whole?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a test to highlights how the number of diagonals is reduced with the interchange

For the linalg canonicalization, I saw that currently we have for loops that loop over the extra dimensions of Conv2dNchwFchwOp and rewrite it as a series of Conv2dOp. I did the same for Conv1dNcwFcw.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't want to lower to a loop anymore generally because preserving the multi-channel input/output let's us arrange each kernel into blocks of a larger toeplitz matrix - so I don't expect that we'll be wanting to rewrite as a loop of single convolutions often.

The other "problem" with the loop approach is that the conv_nd operations don't support strides so we can't use that to represent pooling / downsampling!

fix annotate overwrite
@mdgrs mdgrs force-pushed the mdgrs/conv1dPhase3 branch from 6fbdb50 to 31680e5 Compare May 13, 2026 09:26
};

// Lower linalg.conv_1d_ncw_fcw to a loop of linalg.conv_1d operations.
struct LowerConv1DNcwFcw
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see my comment earlier, but I think we might not actualyl want this since you already have the ncw_fcw kernels written. in the end to end pipeline, this pass runs very early in the pipeline and that means later passes will only see the linalg.conv_1d ops

// h' = hi * g + (ci % g**2) // g
// w' = wi * g + (ci % g)
// 3. Flatten (gW, gH, C) into idx_out = (c * g * h) * w' + (c) * h' + c'
// FIXME why are these interchanged???
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh no! @mdgrs could you link this into an issue? i don't want to block this but I'm also noticing an error in flattening the final output in the ISL string.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants